{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import time\n",
    "\n",
    "# init variable a,b as 1000 dimension vector\n",
    "n=1000\n",
    "a=torch.ones(n)\n",
    "b=torch.ones(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a  timer class to record time\n",
    "class Timer(object):\n",
    "    \"\"\"Record multiple running times.\"\"\"\n",
    "    def __init__(self):\n",
    "        self.times=[]\n",
    "        self.start()\n",
    "   \n",
    "    def start(self):\n",
    "        # start the timer\n",
    "        self.start_time=time.time()\n",
    "        \n",
    "    def stop(self):\n",
    "        # stop the timer and record time into a list\n",
    "        self.times.append(time.time()-self.start_time)\n",
    "        return self.times[-1]\n",
    "    \n",
    "    def avg(self):\n",
    "        # calculate the average and return\n",
    "        return sum(self.times)/len(self.times)\n",
    "    \n",
    "    def sum(self):\n",
    "        # return the sum of recorded time\n",
    "        return sum(self.times)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "向量加法的比较  \n",
    "- 标量加法：按元素逐一相加\n",
    "- 矢量加法：将两个向量直接做矢量加法  \n",
    "\n",
    "首先将两个向量使用for循环按元素逐一做标量加法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.02900 seconds'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "timer=Timer()\n",
    "c=torch.zeros(n)\n",
    "for i in range(n):\n",
    "    c[i]=a[i]+b[i]\n",
    "\"%.5f seconds\"%timer.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后使用torch直接对两个向量做矢量加法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.00100 seconds'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "timer.start()\n",
    "d=a+b\n",
    "\"%.5f seconds\"%timer.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性回归模型从零开始的实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import packages and modules\n",
    "import torch\n",
    "from IPython import display\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "# jupyter :内嵌绘图\n",
    "%matplotlib inline "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 生成数据集\n",
    "使用linear module生成数据集，1000个样本  \n",
    "\n",
    "$$price=w_{area}*area+w_{age}*age+b$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输入特征数量\n",
    "NUM_INPUTS=2\n",
    "# 样本数量\n",
    "NUM_EXAMPELS=1000\n",
    "\n",
    "#设置真实参数,权重矩阵W[]，偏置bias，用来生成数据集的真实标签\n",
    "w=[2,-3.4]\n",
    "b=4.2\n",
    "\n",
    "features=torch.randn(NUM_EXAMPELS,NUM_INPUTS,dtype=torch.float32)\n",
    "labels=w[0]*features[:,0]+w[1]*features[:,1]+b\n",
    "# labels加入一点点噪声，噪声数据符合均值为0，宽度（标准差）为0.01的正态分布\n",
    "labels+=torch.tensor(np.random.normal(0,0.01,size=labels.size()),dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用图像来展示生成的数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x203dd789e88>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD8CAYAAABq6S8VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2df5BdZZnnvw8JrfmBkE6aGJJASDqEbdiY0RZZfiiQZEt3qKi7E2rUctCZNVq1WhmkLPzBrrjDlNYooZidKTGuOGHX1QnOOFIZLZNQ0RACSocKGWjSSTc0JJJKOt0BSTrFTTfv/nHve/u9733f97znnPfce869z6eq6yb3nvue55x7zvM+5/n1khACDMMwTHE5r9kCMAzDMOlgRc4wDFNwWJEzDMMUHFbkDMMwBYcVOcMwTMFhRc4wDFNwvBU5ES0mol1E9AIRPU9EGyvvdxLRDiI6XHmdk524DMMwjA755pET0QIAC4QQzxDRBQD2AfgIgE8BGBNCfIuIvgxgjhDirqwEZhiGYWrxtsiFEMeEEM9U/v0GgBcALATwYQBbKpttQVm5MwzDMA3C2yKv+RLREgC7AVwN4BUhxEXKZ6eEEE73yrx588SSJUti75dhGKad2bdv30khRJf+/vS4AxHRbAD/BOAvhRB/ICLf720AsAEALr30UvT19cXdNcMwTFtDRC+b3o+VtUJE56OsxH8khPjnytvHK/5z6Uc/YfquEGKzEKJXCNHb1VU3oTAMwzAJiZO1QgB+AOAFIcQm5aNHAdxe+fftAH4eTjyGYRgmijiulesBfBLAvxHR/sp7XwXwLQBbiegvALwCYH1YERmGYRgX3opcCLEHgM0hvjqMOAzDMExcuLKTYRim4LAiZxiGKTisyBmGYQoOK3KGYZrK2JkSvvebIYydKTVblMLCipxhmKbySN8RfPOXB/FI35Fmi1JYYld2MgzDhGR97+KaVyY+rMgZhmkqnbM68NkPLGu2GIWGXSsMw2QC+74bBytyhmEygX3fjYNdKwzDZAL7vhsHK3KGYTKBfd+Ng10rDMMwBYcVOcMwTMFhRc4wDFNwWJEzDMMUHFbkDMMwBYcVOcMwTMFhRc4wDFNwWJEzDMMUHFbkDMMwBYcVOcMwTMFhRc4wDFNwWJEzTAHhFrGMCityhikgvi1ifRU+TwzFhrsfMkwB8W0RKxU+AGcnQt/tmHzCipxhCohvi1hfhc+9w4sNu1YYJsekdXlIhd85qyPIdkw+YUXOMDmmmculsd+8OHi7VojoIQC3AjghhLi68t49AD4DYKSy2VeFEL8ILSTDtCvNdHmw37w4xPGR/wOAvwPwsPb+/UKI7wSTiGGYKs1cLo395sXB27UihNgNYCxDWRiGyRHN9JuzWyceIXzknyeiA0T0EBHNsW1ERBuIqI+I+kZGRmybMQzDNDU2EJc8TDppFfl3ASwDsArAMQD32TYUQmwWQvQKIXq7urpS7pZhmFZmfe9ifOVDVzbNrRNHOedh0kmVRy6EOC7/TUTfB7AttUQMw7Q9zYwNAPECvXmIJaRS5ES0QAhxrPLfjwJ4Lr1IDMO0GmNnSnik7wjW9y4uRK56HOXc7EkHiJd++GMANwGYR0RHAXwdwE1EtAqAADAM4LMZyMgwTMEpWipjHpRzHLwVuRDiY4a3fxBQFoZhWpQ8uB9aGa7sZJgMSZvRkIeMiBBwC4BsYUXOMBmSNqMhDxkRrYxroizSJMrdDxkmQ9K6FNglkS0u332R/PqsyBkmQ9IGzYoWdCsaromySJMoCSEavtPe3l7R19fX8P0yDMMUGSLaJ4To1d9nHznDFJAi+W+Z7GFFzjAFhIOgjAorcoZxkFfLt9m9SNqRvF4LACtyhnESerX6tN+RcF62nawUbp6fgjhrhWEchF6tPu13mGiyOq95zmJhRc4wDkKvVp/2O3kir42wXOc1jcx5TgVl1wrTcjTDl5nE1VF090heXQ2u85pXmdPCipxpOVr1Zg2BnOSGRk7HmuxMk+Oanvm4eUUX1vTMz0rc4LRqkJhdK0zL0UyXRV7dDRI5yT314ih2DZSXXPRxF5j8zjv7j2PXwAiuXXocyz4wOzuhA5Jn90gaWJEzLUczb9a8BjDlBCOt5zU983Ht0uPek51pciy6jz8NeZuwWZEzTEDyqtxME0wcK9o0ObaqdetD3iZsVuRMHXmzNopEXpVbXieYopK388nBTqYODhaGI1QGTdpx1Akmr9WJRSJvGUesyJk6WjWyH4dQCjhUZWioybUIk3SeS+HzCitypg7d2sjbjeUjT1qZQyk83xS9qP2FmlxDT9JZXBtFmGzyBvvImUjyFtjxkSetzKF8oFEpeno2iW1/Jt97kliGPk7aeEgW10Yz/M9FjwuxImciyVtgx0eetDKHClqu6ZmPp14ctVrkuiIcO1PC/TsGABBuv26JU6mEUKJ5mfBUmhEwzpuxEhdW5EwkecvE8JEnLzJHWeS6Inyk7wgeeGwQADCzYxrW9y62WoohlKg6RggLv6jkzViJCytyhsmQKAWhK8L1vYsxXpoAQFUlbrMUQyhRPZulyFZpGoo+IbEiZ5gMiasgOmd14I61K6r/z9pSVK3wolulURTdD+6Cs1YYJiVRmRu+mR2m7XzzlZNmj6gZIlnlRucl66mVs2HYImdyRxzLKQ9WVlSgTH4uXSaAwO3XXV4nr20cn2NMGqxrhBXejECi6Zy18hMHK3Imd8S58fOQbRClIOT746VJPPDYYQDAzI7pdfLq40hlNF6aqAZAbceYNGhpc/2EnCDTKtAkspiui6L7wV14K3IiegjArQBOCCGurrzXCeAfASwBMAzgNiHEqfBiMu1EnBs/D1ZWlIKQn0+5FoRRXn0cqYw2rl4eWcTjE7SMoxBDTpBpFWgSWfJwXTQSEkL4bUj0fgCnATysKPK/ATAmhPgWEX0ZwBwhxF1RY/X29oq+vr4UYjNMMUhj2Sb9rlpktLP/ePX7UsF/5UNXRirEPLis8ihLsyGifUKIXv1972CnEGI3gDHt7Q8D2FL59xYAH0ksIcPkiDQBSpU0AbakwUf5vZ39x2v2Hac8P09NofIkS15Jm7UyXwhxDAAqrxfbNiSiDUTUR0R9IyMjKXfL5IG8ZCNEUa6WPIT7dwx4y+qrgBvVIyUuY2dKGC9NYuPq7kxTF4vw+7cDDUs/FEJsFkL0CiF6u7q6GrVbJkPykM7lo0zK1ZKH8cBjg1ZZh0ZO49M//B2GRk4D8FfAUdv5WJNJFaLre/KYZ3ZMr+479O+V1e/PE0R80matHCeiBUKIY0S0AMCJEEIxxSAPASWfQNj63sUYPV1C/7HXsaZnvtHneu+2/soalv344aev8crmkPtP67s1HUPalMM4S7Ml9UGbxgvhz85DJlLRSKvIHwVwO4BvVV5/nloipjDkIZ3LZzLpnNWBubM7sGdwFDv7jwNAnaK4+9YeAP2VVzuqkjGNE+oYfCco/XuSOEuzqQsy33fbqlR9VtSx7r61pybY6kseDISiESdr5ccAbgIwD8BxAF8H8C8AtgK4FMArANYLIfSAaB2ctcI0mlCWtGkcPTvE9d0te4dhKwiy7SfrIN/YmRLu3LofuwZGvDJafMe6eUUXdg2MYOPq5dUGYGmPpd0zWGxZK96KPCSsyJk8kVY5+Kb1qfndSRRmlkos5Nh6+qMsaEo7SQD+57pVsSlyruxkGkoeLaqkrgWJryug3NlwErIgKO65SOo7Lj8JvARXj/OkbjLTMahjLfvAbIydKWFmx/QgrhJ2u5hhRc40hDjl5j4KLnQJ+VMvjmLXwAju3Lo/tjLvnNXh7BuubnfH2iuq/5fW5XhpoqroXPuNq8RM53zbgVex+c96sayrvjd6Enwml5CxlDzEZfIIdz/MIa2YfjV1w1NkWp+a1mbLAQ+Z+tY5qwP33baq6tNNMuaWvcP45i8PVnzgfsjURYBqjiX0ws8AYePqbiyZOxNDI2dw77b+6jZp99WsPHmmFrbIc0iz06+ycH+o1mTUmOq2Mh8aqG00FeoRW/Xnrlx0IVYuuihhep7QXv32K/clg4GA/fc3vT80chr/41+ew1ULL8TntHx1/ZyvW7UQ926rzcxJe635Wsj6Ocyji63QCCEa/vee97xHtDujp98UD/56UIyefjPyM9e2WfDgrwfFZXdtEw/+etC5XSPkGj39pti0fUBs2n7Q61zF/Vwe66ce+m31mPXv+JyPuOfCNaZtLPn+4Ik3qp9LufWxfOVp1LWlH6/t+FV5Gn3dFwEAfcKgU9kibxJxlvBqtIXua+1GFbLIbXytLlvgTPUrR8mgjxF17qSca3rm49qlx43Lq/nmqscp5nEtymyzcuX7avbL3bf2oDRRtshNeejjpUln6l9cn3Oo4iHbOQ2Rp9+O1j4r8iaR51atvjd3VCELEO9GDNGuNK4S1jMsTGPGVXY+xxG1KLML3WXyo89ca91mvDQR1AhIalTo59B2Tl0VqVnLWGhMZnrWf+xaaV3SPBoPnnhDfOqh34rBE28E22eWj+cuF4jLHZS1XHH3E0eWIrg7iiBjUmBxrXDWSo5ohWwVNRUPQKz2o9JKlWX0vrjWnZSf3bl1f9151c933C6JeuaMHA8oBy9dTbrSNtPyuVZ8XQx6lpBr3FAtZbO81tux7S27VnJEXtY2TEuj14/Uv6evDC9zxGV5vCyM0eW0Zcj47lcdT/0saWm+6zz6nGPf30HPEgp1DbqurbZ0f2QIK/Ic0YyqtahgYRKSHkfSYg890Ch7fQDlY7rvtlV1hTFqup/6qlZe2lDPkSqv7ruW51OVJ2qCsE0G+r5lkNQlp+/voJ6/kNdg3O6MTApM/pas/9hHnh90f+Km7QfFZXdtE5u2H0w1Thb47ENNJzT7rg+KTdsHUsnpm56pbvuJ7z/l9JmrMvqkS/rsWx9TTVtsBHn2VedZNhfg9EMGqH/crbeCSXv1w2Z9+Vr4cfpvu1LqTIVHQyOnq4Uwd6xdEVs2naQZRz77iHoq8dm3ngIqnwikiwkI585wncM8l9O3mmuHFTnaK+806gK+/bolNW4HX3zygn1dCvp2ujshbkqdvmiEbZ++14GuoEzf0/30oa4vH+Wop4DuGhjBsq5Z+MIty7Fy0YUYL01i7EypRpak98DU5FrfLybP91WruXZYkaP1ZmcXcfKq42D7nlr44rqxXQFL3W+8Ze9w3VqUtqZcY2dKWDxnJpbMnYkv3LLcuc8410FUUNKWTx+l1LOIUUhL/OnhMczsmI5v/vIgZnZMi1V0ZpNrKl99su4Yx0uT1eBx3u6rPD8tJIEVOVpvdnbRqAt4SrFOYtfACFYuehUHjr6GXQMjNa4RYKr606ZY1MlAZpZ85UNX1igUuf3G1ctrmjg90ncEDz/1MgDg6eExvPuyOTXyqYpJBjvHSxM1Fqtp26igpPreqfFSjfyuJw89UGvbv+l8q61q1bFlsNdVZBN1D9jkVoO68jed+i26YzXUijOJ5dnabwasyNEas3OzL2xbaby8mcdLE9VVYwARWf2pp8TJKkhbJsd4aQIbVy+v67ddVs4TAMhagSr32zmrAzM7plUs1ulG69qU3RG1rJou/3hpwujekNvdvKIrUlad8gQ3lZHjqqJ0lf/biPMkFzcuoB5DlHtNv75sx9NusCJvEeSFnXRxBElaXylgLo2X7pDbr7u8+p3x0iTWrbrEaAXHSYnbsncYDzw2iI2rlxutaFOAU/rb1/TMx/d+M1Q9XpOLZ7w0iY2ru+u2lZao+p4JOZa0yAHCA48dxtlzkzh8/A3cfWsPlnXNrj4RnC1NYMve4eqkZOrLov9OtgkrFHGMnaSGket3TtL/pp1gRd4irO+dKnx5pO9IYislVDGPejN/7zdDde4Q1VcLiKo1qSpdiS0ne0q++haytuPQ31cbUNkWJn7gscO4oXse+oZP4Ymh0ZptTcE++b6u3B/d//vKhFN+Stl9aKQyXjkIW95WYPPjL1XOUdmyNvVl0Y+j3Fys/tyZyKsLwzUBuK4vhhV5U8ji5pCLI+i+0LhySRdF3DFMN5areEX1R58tvVV515zyqAYyTa6H26+7vG4pMbsLZrImUBoV8FQnSAB43+Vzap4eTME+oNZdZPLfS0tb7w8uz8EN3fOsMurvZblkXJKJPavrO07MoO0wJZdn/dfuBUFJijoaQVK5bMUVUePJzzdtH/AqglG3S1LQocpj+77euEsWEd324N5qYU/cPtq2ZmAmGdIel88YWTfJiitPXPLUL7/RgAuC8kNe/Xu2FEC5GrpuAUWtw6n6hk1+ZN+gmMlqTrL2p+pr3rL3JTzw2CDGSxM1LgndjSGDiBtXL8ctV15ccy5UXAFFW8taPfPFlL3jQ9w0Sp88+DSkSetMMr6NdgqIsiJvAmn9e1EFKElvRl0uNYBqqgiUn294/1LcvKLLGoyLSl2Lg2ufOuUg6GGMlyZxx9orahSqrYJVd1lIV4yaDROnf7jLXaVn5iTNNdfPY9QEquNSeEmKpmzyhDJc0vTLb1VYkReQqAKUUL5EVSGoqX/659J3bQvGxb2hXKlmrn3qnC1N1rzqcsi8Z5mDffbcW4AQmNExHafGS9WK0HLKZLScNtkfeGywLu9d5dR4rbK3PS1IfNP0fIt8XA249EnB50lIJ8vAZFFbBISGFXlOiKN0o4JfEp9HS98bwaQsa4tBptcE3lTlEPeG0i1pU8623KfJ6pQySGZ0nFd3PGofbDUHWyKLl5Z1zbJmAvlMqK4l3fQnnqkJw93vxjdNL2oC9bk+5DmT2T16wVWziZr0siRPQVdW5DkhrT/PpCx9LOEkyt70aC8twCQWm74PPZ3QlrOtKhg1fz6qslBPY5Q52KpFvm7VJbh26fEan7gpd1s/v7rifnT/q9XKVtkDXY753iWduHlFV7V1gJww1q26BAeOvoZ1qy4xniuXj9s26UXFKKLwjWekIZliTNbkLQR58sGzIs8JcW4qHzeK700RR9nL0nqTsraVyPuiHpMpndB23Ot7F2P3oZHKwhEv4Y61K4zBtjU98/GPv3sF/cfeQM+Cd1St7TU986052PIpRL7qOeemyVMPbJ4tTQAAzpYmqvI/fvgk9gyexA3d87Bn8CRWLroIKxddiJWLLqpOiCa3kc2Vo+fW+zw5APGCno1wUyRRjEmbvIUgTz74IIqciIYBvAFgEsCEEKI3xLithN4LI+pGidNgCqi/CXxvCtVqu3/HIZhWsVH90rqyNlVKxlmqTX5fWqdSsZpkNh1356wO9C7prBTWUM0xAaix2KWyOzI2XlWiO/v9Fz+OunFNeeozOqZXX+V7o6ffxJ7Bk+hZcAFuXD6vOjFKP7ptPzZXjp5b7/PkoMtdfppqbpOrJIqxmX7wPPngQ1rkNwshTgYcr6Vw9cKwbW/rve3jRtFfoyx0WcFYlm+60XJTfeFyDN1K1f/vs99v/vIgOmedj7Ez55zBS9uN47LK1GDd8otfwfb+4xgeHcdH/ugS3Lh8XjCloVrGamBTlU09j3Nnv6267y17X8KGG5fW9F/xnchMLg9T1aeP3HGbXIUmT4qxaLBrpUFIP6xvLwzdCgbiBSxtqYS2ccryTcK1zFmcCWQqy6HWyjNZi1v7jmBo5AyWdc2KtBrjPvarn331j3vwuZu6MwlQuSxj2WNFjSnUti8YxM0rurBrYKSmg6Auo+k4bb+JqX9NlNxR64kyOcZUJRT3D8BLAJ4BsA/Ahqjt272yMw4+VXqfeui3kZVuja5yk9V3tz34hPjE95+qVjWaqvJk5eO+4TGrjD5VmULYjzPE8bvG2Dc8Jm75zi6xb3isTmbX72Nahi1quTqf42jn6sdWBhlXdl4vhHiViC4GsIOIDgohdqsbENEGABsA4NJLLw202+KSdDUaHZMlmGScpNgCcOOlCdzQPRd7BstNpqQv2uQeWNY1Gz/89DU1bhndKpVPNOOlSTz46yFsfvxFPH54BH/7sXcbA396F0jfALErjuEqjtm04xCGRs5g045D+L//9X01x2jLwwdQk/GjHqvaAE39PERuuL5/dmcUnyCKXAjxauX1BBH9DMA1AHZr22wGsBkAent7Rd0gbYbthoybgtWItDAXU778qe5/aln7ey6bA9WdZFJcpmMx+Xllx8QbuucCAPYMjtbkd8tJzZT77QoQS6WvxjEOHH2trh2wPoaaw9yz4ALsGTyJ0sSUn9uWh6+3PtDdT3oDNFtuuGkiMneHDEPSvOk85Vu3KqkVORHNAnCeEOKNyr//I4D/mVqyFidOZoILl2JsBFO+/EmjsnEFOAF7OpwrsFdOJTyC/mOvVy1PtRR+3apL6vqh2HzJuuU7XprAvpfLxUB3bt1fo8z1vOxT4+cAAGfPvYXP3dSNwydOe7UR1guBTEHGqHNhOo8+GSxpSJo3nad8axuFn2xM/pY4fwCWAni28vc8gK9FfYd95HZCdJvLknI3wAGxafvBGr+ur29adhPctH3A+xhNY+vHHHUO4vjOo+IOcl+f+P6T4rK7tomPb36y7hy4fkefmIBNRleswOVvDxE/SepPt8mZJ798I+6hEMeNrHzkQogXAbwr7ThMmTg+S5s/NLR1offvkG6AA0dfr6mONH1HLx7y6T2iYysEcr3Wuy/MFae2TBDVtWHLy5b9xFUrXM9f1/cHqEVDc61uNXncLteKLrte7Wo6L77nNyRxM6iaQZIc9rhkedycflhQXP7Q0BeMrkhkmuK6VQurQTxbQdIN3fNqCmRcvUdsyO+8d0lnXYm+xKYsptwXyXuE6L501Z119609NYHMqGCjqWhI3w+A6sQzXpqoLo9nOm51n/Jz9Rz7GAY+Siyk/70RSjMujQj6ZnncbaPIC+8D03Blq8S5YFznxdYA6461V1S3kUE8k2UsleiNy+dVx5bWKNDvvbbo1HdQo0jKWSbDcFWjyowRveI0qte6WpAFCNzQPQ+7Bkaw/sG92PxnvdjZf9yYYWPKbFEnH/lEY3oiUWX++s+fq7xLXplL6j5N/V3UgGhUzYFtHz7ZUT60a6ZMlsfdNoo8j49zaXAFFONcMK7zEuecmSzk+25bhS17h+uWRvNdW9RU/q9av3o16vrexTXpg2rGiO5uUNMUTVbmVBB3orrO5rHXz2Jo5Azu3VaehOR2NteHeg5lVo/LDSZ5dP/vsWdwtFKks8R6XkwWeJmppmP6b5i0W2Czs6MYN22jyJvxOJfmKSDqu6Fmd9d5SXvOyimD0yqLLE83pta5MCkdNY1PdfOoaY9AfRsEPbVRWpe6a0SVXW9LsG7Vwuoam6p7xdWyV8q47+Ux7BkcrbZckKiKFoBxbU+JzHPf9/Kpan7++t7FVT+9bG/gWsM0SbfAVnuabUXaRpE343EuzVOAq/jEdDMlvdl8y9uToroLbP5twOzqkMrm7Lm3jG1YdTePTB9U89bVIp91qy7xKpbRz6WUtXNWB3746Wuq2/kUGXXO6sDZ0iT2DI7ifZd3AhB1izvrsrjSNuVEJV0cJpeHaTKRMiXpFmg7zi17h3G2NIEZHdONBVRM42gbRd4MpGJRmyHp2BSw7REdME8KSSaNOJOD3IfL12rClElhmpj04CSAal9uCNQVHakyqMpfdxfoRT66G+X+HYdqFrGIcy5NgVvTd/uPvQ4AOH8a1VnLutJ1FYepE5VUnD4uD12muJOzabJR3VqAXyM4JjtYkWeIWo1ou9BdlrfNNWAiiRskzuQAwJiVopfC69iyOPR9r+mZj92HRrC4c2Z1yTNpbZ6bfAsbV3cDIKsMtmwKVfnJxSLUcySXgTs1PhWU9D2XpkWVTZP3Nz58dY1LxreS11Tdqk9UcbJSfNbwtAVDVat+alKZrFrkecpAaUdYkWdMXAXsU/UosbkAspJNf8/WD8SU/SGPR/Y9P1uaqLZuHRo5jXu39eOJoVE8MTRaXfKsNvOl7DpQ3QJ6ZorpOHTlp7fIlcvADZ04jYeffLkqZ9S5tC2qrE7egKha36pLRkUPhqpPG2qaYtKFQ6RMpqcilzz6NuZJ5QrjGOo5Yt96Y5h2zz33NHynmzdvvmfDhg0N329WjJ0p4eEnh7G0azZmKIEsAJjRMQ29Szrr3rd9vrRrdtXisX1H8vCTw5Ve3uXFFeISRzZ92xkd0/D+K7pqrDWTLPJ41vTMx5a9L+GhPS/hkX1Hse+V1/COGdPxk6eP4MjYOHYNjOD6ZXNx2dxZ2H34JDpndeDG5V01+5BjqzLIffW/+gdcdcmF1mMx/UZjZ0p49sjruHZpJz52zWX4/WtnMftt03DVQvs4koefHMZ3th/CLVdejBuXTy3OPHamhL7hU7h2aScAwne2H3L+PvL8nJsUNdvq48vf+u3nn4cDR19H3/BY5Ni2fbmuLXWbs+cmq+fsqksuxNvPn4Zzk29hxTvfEXl+5DlKc32mwXVPFplvfOMbx+65557N+vtskSsktSBCpjbGsaxVy7QZ1o9PPxDVGpS+6hu65+E9l11ULSgypRXafMgm9MWaTZgsXzWn+zeHTmDPYHkJtsMnTte5i3xX3VHH1J8gXOewnB1T/7Shv8qeNnELnJJ029QteD0DKWp/pieWRtFq6cZRsCJXSPrjx/FPh1K4+jiux2afop+4flOdKIVrCtQBZVeHuqK970TmWqxZ8szLp/Clnz6Lb//Ju6qBybPn3sIDj9XnfJcLi4A5M8835rjrMQHXccpX17GYsltcrjSTws/a2IiaVFzHI4PMcVoxhCRJzKjIsCJXSPrjhyrAiYM+jkt236IfXz+3DyZFZStAcY1vm0TU76xbtRAHjr6OdasW1nz3Sz99FkMjZ/DFrftx+bxya9vlF19gXBv09uuWVLNabNWyakwAgHcsI+4xu0iaEqpfH0kmZ9e+9euomdY40H7Vo+wjV4jyGYcgjg88zjgu2V37VD+L8nPHkdnkH7X5LW3jy/4e/+epV5y+d1kQM2dmBw4cfa06/r9feCH6Xh7Djd3z8PNnj+HmFV1YNGcGfvL0UcyZeT4OHH29uq3q879j7Yo65abHBK665MJUv6PrmNVzFPV/X/TrQ/19lnbNTu1P1q8jU/yASY/NRx5kqbe4f0VtYxuq/dOfEcQAABNYSURBVKZvu9Mk48oWsWor0yRyJdm3q1Wp2hrWtISZCX3Js6hl0f562/M1bWVNssl/b9o+kKptaVatWKPa88r/b9o+kGr/tta4aVHbHOepTW2rgIyXemsLsnCLAOZHdBtRecd68YssCdcDibZAns29EueYTGPInHB19R7TdrosAKqLF8u+J2reuMnP/cTQaM2iEKaiGz24GHVuXccb8hFeuiVGz5Rw/46BqruoPugZvSi3C59AdRLUYC+nHDYOVuQxCHXBu/Kzo7DlHcsx9OIX9YYHzBWStokljlJXqxxNueOm1XtMytBUhSgDuTJTY03PfKxc9CrGSxM4NT4VKJUNpp568SR2DYxgy97huiwWW8GVuu+nXhzF3bf2GLsiSrIKcMs8dBmQVTNE9EpbtUI0TRA9pD+53YKMeYEVeQxCXfCusuwo9DQ09fum4he16ZNE/656850aLzkVsk1ZqFWO+s0snxQ2ru6uUY6q8pcZJv/9j3vq0ur0TBAA1VQ4dXGLz35gGe5YewXu3yHw25dOQc9ikbK4JsLdh0awa2AEpYnn8MTQVGMqU7Vj3MCs+nSkPoGovWB29pfTMWXVpFoh6iq1z0u6XbsFGfMCK/KCYcs7tikRH7eCuo10gdgUsk2xuxYzsE0+qvLf2ncEQyNn8Ff/2o/H7rzJeMwqqttFr+pU+5nY8r9tE+FVCy/EE0OjWNY1G++/YqoxlY+StG1XTX0sTdQssiy/Y+oFc8faK6pPIrK9g8vaLYIl7Oofz6SDFXlBMKXzqZaiy9LUcVlNa3rm4/HDJzF6+k0AZoUk+6aohTgzO6bV9R3R5QZgLXp575LOqkWu9gOxLQChHoNedu8qarFNhJIZ55dL9ufMOr86hu+qRurkoh6DnLBWLrrI+LQxerqE/mOv4wu3LMe1S+dWJyA9hU89Llceel4Vptpoy6eoiPGHFXlBiPIn2yzNuD06dvYfr1Y4zp39tpqbTW8SdbY0AQB4cugkvvlfVtbIYZPblpd8aryEBRfOwHd/M1hxi/g3xlLxrcC0TWYma370TKm64s4da6+IfPrRJw+Ta0iVdUbHedgzOIobl4/VTECugpqo2oA8KsxyDGeqf3wU3KvFH1bkBcEVIJUXulTcgKj6VvUbPspN4LrZdBlmdJQvn98Nn8LO/uPW8dRX281577Z+7Bk8CQA1BTkuF4rK1GLPk3XuiziYrPnrl82tfFr2udtcTPKpQV3RSFrmJktajmMqt49ylbieEuIqzEZRjuG4G22p5MXvXwRaUpG34kxusiBt76mtc01uDPXVtB/bzabvb2oZMrvC0L9jU4JfuGU5zk0K9Cy4AJ+7qdsYVJQuFNPvO6UUu2uUok0ZmMawWfO1i13Uu1BMaZG2boOmpyjTdRoVNDS10FW/G0dh5pUi+P1zgym5POu/rAuCQhY45BVXQUraYhW1sMhU4BN3fNd3bQUutn2ohUXq7+va3vS+6RqJum6iZNeLsExFUZu2HxR/va2/5twmIauCJCbfoJ0Kgnxm8qJb7a7HTt8UMNkH/O5be7Csa8qqUzMpZMZEmiIml598yrc/4eUCklk1ej8U2zHr7z/z8il8cet+XLOks64XSNR1o68harKoVevY9DTywGODuHlFF3YNjFSfmKKKtEzXJ6f5MSotqch9LvK0/rdGTASufYR47JQ9SoD+mgWRyz7W2rUvXT76KEz+XL0wR3ZAtFU0qvuU/l8V39/jSz99FsOj4xgeHa8LJEZfN7ULF8dVpiZ/v2vCCu0fDnHNFt0AalVaUpH7kFYRZhmI8QnchbDI7r61B0DZItePJ2pJsTgW//KLZ9dMGKbiFldFoy6HqS+27+/x7T95F764dT/ev7zLa+kzYOr3WLfqklRl/SZ/v+06zMI/HOKa5QBkTjH5W7L+K2rTLJUQfmjb96d8rwcT+6J9t5FNjsp+2/pGR2l88dKX/YnvP1Xn1zZ917fhku27cX8P31iKa7u48Zhm+rZD7Jt9880F7eQjbwSu4gwfXJaNK5sh6ZhTVv5E1f8tfdHS6jflLEflK7usM9XinzOzo6YwCCj7xbfsHa4uNOHKtvBdn/TUeHQ2ikS6kNQyeBPSnTNemqjbLq7l3EyLNsRTHPvm80kQRU5EHwTwAIBpAP63EOJbIcYtCkluTpcCiLpZXIrJNqaas7xxdXdVeUXlHKtjRhXb6J8v65pds+iwKfAHTFV7uqoRo86xq3DI9l09VdN2zsuyiKq8d6xd4Wy+5cI3r94E+6cZG6kVORFNA/D3ANYCOArgaSJ6VAjRn3ZsE3m8mJP4M9NYNi7FZBtTlVF+XyovkxVsUlT37zhUszZmVI64aTy1aEYNqMpyf8DsH486x9K6PnvuLaxcdGHNdlHFM65xp6gNdCa1rOOcMx32TzM2Qljk1wAYFEK8CABE9BMAHwaQiSLP48Xc6MfNtIEwl5UtMZ9n89qYNrlMVYyqxVwbUC2Pef2yubGeUtR9yGCp7iKKKp7x+e1uv26Js8DKJpfeI0bHNI7tN1EnpDwaNEzzCKHIFwI4ovz/KID36RsR0QYAGwDg0ksvTbyzdqj2CpFDbFoM15TLbVu02XSeZR8SU7aH2qhJYqpitJXaqz1OfHrCuPaRNgPEtG+1BYJedu+SK6pHjOm3tBkr6oQExMvlZ1qbEIqcDO/VmWxCiM0ANgNAb2+v2aTzoB2CLWmeOqqNnk6/ic2Pv2QsXlG3k+4GXcm5zvOj+39fEzCVMuuuET1oa+tW6Cq518+Bq4zepvjjXjOqEpbpkvpnulw6UROXizgTUisbNIw/IRT5UQDq1bQIwKsBxs0NjX6MTfPUIRXNDd3zKu+Y5tnopdlc37E1edKDpr4K1CSLPsHYsm5sHQfTsL53cdWSVl1D63trl59z5aC7Jq4ofKtUW92gYfwJocifBrCciC4H8HsAfwrg4wHGzQ2N9suneepQFc2j+18FIGq6IEor0xT8jOpl7kqLVFMHVfeDz8RnkgVAjStDKvAbuudh4+ruTApo1AlbrXQ1uaVCThxx5GJ/OGMitSIXQkwQ0ecB/Arl9MOHhBDPp5YsR4T0y2d1U5qyTNQKyPW9U8uYfeH/PYP/9fF3V7eTx2VaNUfF1zd/59b9Xr3DTePq51p9CpA9Sm5cPi+1+8TE1IRWuySb6ff3XWwiBHkM8DP5IkgeuRDiFwB+EWKsPBLSL5/VTWkaVy14AYDeJZ14YmgUTwyN4pG+I9Xt5PENjZzGgaOv1SkndV1JWbzjksPU1MoX/VybFLtvab2UXS1Esh2HuiIPICLdTqZMmKwm6XYI8DPpKHxlZ9EeO31uyiTHZBpXL3iJ6h+uK6fani+13RBtsrrcL0kw+YXjuDX0ro2245AFStL3L59ibBOv6Xz7TtI+vdBd5yApRbtXGH8Kr8iL9tiZVWdG27i6YnUtOKBa8KpffePqbmxc3Q21G6JN1tBZRaZ8bNtkaFJU+rZ6V0f12OWry9UjMR2nr+Vs+n0bcR0X7V5h/Cm8Ii/SY6evRZT0mFz5zz6oFrxk4+pu5wK+WZ9/Wz62q0xfbqNTnshW1L0vP0saF4i7rSuNMMl5zPq6YvJP4RV5nvPKo4pybCQ5piRBRhOq5epa/DeNrCpRFZBx8rHTuDsaiemcZdGywWe/TGtQeEWeZ/QbLEuLKG2QUaJWMEo/cShchT+2CkhbPrZv9WU7WKHtcIyMG1bkGaLfYFlaRL5BRt/HcJNSNAXmAHgH0GyZNUDZ4l656NWaVrEuWeO0122WFdqo4CJb2gwr8gxp5A2WpIoy6fZ6JojveLbMGlPe+2c/YF+3EwjV0TBb8ujWYVoTVuQ5oBFFQnLcuIUsppJ0V7DORdRko4/rUsghOhpmTV4mFKb1Oa/ZAhQV6XKQiwanQVpuj/Qdid445bhSAe7sL3fQizoOqRR39h+vjiXf65zVUfPvtOhjucZe37u4rt+Lz/E0kpDnhmFcsEWekJCPzVlZbj5pbnG7+fnKmLV/2GZ1szuDaUdYkSekXDxjXscxLlm5AnzS3HwVtCv4aaJZCjXEpMgVkEzRYEWekHLxTG1wrogk7dUNhLXgQxGyeRbAVj1TDNpSkYeyuNoxmBXXgvch7u+RtcXcjr8rU2zaMtgZKrjYqGBW6ABemvGyOOa4v0dWwWGJeox5Cp4yjI22tMgbbXGltSCjHvV9x6/tZnjYOp4vPvv12Sbu79HI34/dLEwRaEtFrj/2Z/2onlYZRCku3/HVboam1L24uFYWijNpxPVr27bP4ndkNwtTBNpSketkbXX5KgObIopbSGMby7eM35epBluTNSvrrOmZj3u39WPXwEiwScOHLH7HvBQXMYwLVuQI2zbWRFbl867xTWOFVkq1DbamYbw0UdMA6+YVXXUtcLN8+uHUQ6ZdYUWO5AoulO9aEuoxXl22LOlTQBz0jolqy9k4za7SkmaimnIFTVRXEWJLnCkKrMhTEMp3LQllMavLliXpSJgUW8tZlbz6nNVFnhvlCmKYULAirxAnA0NdBCGPRTFx9mtqihXCpZDU398IGhE/YJhGwoq8go9lGrUIgk6zlFaS5cniLGjsQ57T9hoRP2CYRsKKvIKPFatar1HLjmVFVsG40E8PeXWhAPmWjWGSQEKIhu+0t7dX9PX1NXy/rYC0nL/yoSvZgmSYNoOI9gkhevX327JEv8jY+nC3G1w6HwY+j61BoRQ5X3S8WIEk634r7QKfx9agUD7yPAfQQtHKBSkhj033c7fyecsSjhe0BoWyyNvBrdBsCynLpx7fY/ORQX8yafZ5Kyr8hNcapLLIiegeAJ8BMFJ566tCiF+kFcpGO6SIZWUh+VqsWT71+B5bEhnYsmTamRCulfuFEN8JMA6D7CarPKzs43tsSWRoh0meYWwUyrXCJMfXLdWMR23dlWKSoRmBbg6uM0UhhCL/PBEdIKKHiGiObSMi2kBEfUTUNzIyYtuMyYg8+0J9/NvN8IGz350pCpGuFSLaCeCdho++BuC7AP4KgKi83gfgz03jCCE2A9gMlAuCEsrLtCBxqmob6QNnvztTFIJVdhLREgDbhBBXR23LlZ3ZwCl4DNPa2Co702atLBBCHKv896MAnkszHpMM7qXNMO1NWh/53xDRvxHRAQA3A7gjgExMDMbOlHDn1v2VjBRq+Tx7hmHqSWWRCyE+GUqQVqHR7o1H+o4oy6otYZcKw7QhhSrRLwKNbiPACyIwDMOKPDCc6cAwTKPhgqDAuPK1sygw4VxnhmHYIm8gWbhd+AmAYRhW5A0kC6XLPUYYhmFF3kBY6TIMkwXsI2cYhik4rMgZhmEKDityhmGYgsOKnGEYpuCwImcYhik4rMgZhmEKDityhmGYgsOKnGEYpuCwImcYhik4rMiZtiOL5mUM00xYkecYVjjZwB0jmVaDe63kmNDdEnlx5jLcMZJpNViR55jQCqfRqxflFW5exrQarMhzTGiFw5Yow7QmrMjbCLZEGaY14WAnwzBMwWFFzjAMU3BYkTMMwxQcVuQMwzAFhxU5wzBMwWFFzjAMU3BYkTMMwxQcEkI0fqdEIwBebviOs2cegJPNFiID+LiKBR9X8fA9tsuEEF36m01R5K0KEfUJIXqbLUdo+LiKBR9X8Uh7bOxaYRiGKTisyBmGYQoOK/KwbG62ABnBx1Us+LiKR6pjYx85wzBMwWGLnGEYpuCwIg8IEX2biA4S0QEi+hkRXdRsmUJBROuJ6HkieouICp05QEQfJKIBIhokoi83W55QENFDRHSCiJ5rtiwhIaLFRLSLiF6oXIMbmy1TCIjo7UT0OyJ6tnJc30g6FivysOwAcLUQYiWAQwC+0mR5QvIcgP8MYHezBUkDEU0D8PcAPgSgB8DHiKinuVIF4x8AfLDZQmTABIA7hRD/DsC1AP5bi/xmbwK4RQjxLgCrAHyQiK5NMhAr8oAIIbYLISYq/30KwKJmyhMSIcQLQoiBZssRgGsADAohXhRClAD8BMCHmyxTEIQQuwGMNVuO0Aghjgkhnqn8+w0ALwBY2Fyp0iPKnK789/zKX6KgJSvy7PhzAL9sthBMHQsBHFH+fxQtoBTaBSJaAuCPAPy2uZKEgYimEdF+ACcA7BBCJDouXuotJkS0E8A7DR99TQjx88o2X0P5cfBHjZQtLT7H1gKQ4T1O3SoARDQbwD8B+EshxB+aLU8IhBCTAFZV4mk/I6KrhRCxYxysyGMihFjj+pyIbgdwK4DVomC5nVHH1iIcBaCuPr0IwKtNkoXxhIjOR1mJ/0gI8c/Nlic0QojXiOjXKMc4Yitydq0EhIg+COAuAOuEEOPNlocx8jSA5UR0ORF1APhTAI82WSbGARERgB8AeEEIsanZ8oSCiLpkZhsRzQCwBsDBJGOxIg/L3wG4AMAOItpPRA82W6BQENFHiegogP8A4F+J6FfNlikJlWD05wH8CuWg2VYhxPPNlSoMRPRjAE8CWEFER4noL5otUyCuB/BJALdU7qv9RPSfmi1UABYA2EVEB1A2MHYIIbYlGYgrOxmGYQoOW+QMwzAFhxU5wzBMwWFFzjAMU3BYkTMMwxQcVuQMwzAFhxU5wzBMwWFFzjAMU3BYkTMMwxSc/w91ZSo0u6ZqdwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(features[:,0].numpy(),labels.numpy(),1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter(batch_size,features,labels):\n",
    "    num_examples=len(features)\n",
    "    indices=list(range(num_examples))\n",
    "    random.shuffle(indices) # 随机打算indices的顺序\n",
    "    # i取值0~num_examples，步长为batch_size\n",
    "    for i in range(0,num_examples,batch_size):\n",
    "        # j每次取出一个batch_size的tensor，最后一次可能不够一个batch_size\n",
    "        j=torch.LongTensor(indices[i:min(i+batch_size,num_examples)])\n",
    "        # yield是一个生成器，index_select在指定维度取j指定的相应项返回新的tensor\n",
    "        yield features.index_select(0,j),labels.index_select(0,j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE=10\n",
    "\n",
    "# for X,y in data_iter(batch_size,features,labels):\n",
    "#     print(X,'\\n',y)\n",
    "#     break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.0159],\n",
       "         [-0.0008]], requires_grad=True),\n",
       " '\\n',\n",
       " tensor([0.], requires_grad=True))"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# w_learn,b_learn是要通过训练学习的参数,在训练前先随机初始化\n",
    "w_learn=torch.tensor(np.random.normal(0,0.01,(num_inputs,1)),dtype=torch.float32)\n",
    "b_learn=torch.zeros(1,dtype=torch.float32)\n",
    "\n",
    "w_learn.requires_grad_(requires_grad=True)\n",
    "b_learn.requires_grad_(requires_grad=True)\n",
    "\n",
    "w_learn,b_learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义模型\n",
    "定义用来训练参数的训练模型：  \n",
    "\n",
    "$$price=w_{area}*area+w_{age}*age+b$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linearRegression(X,w,b):\n",
    "    return torch.mm(X,w)+b # torch.mm 矩阵相乘"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义损失函数\n",
    "使用均方误差损失函数：  \n",
    "\n",
    "$$l^{(i)}(w,b)=\\frac{1}{2}(\\hat{y}^{(i)}-y^{(i)})^2$$  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def squared_loss(y_hat,y):\n",
    "    return (y_hat-y.view(y_hat.size()))**2/2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义优化函数\n",
    "小批量随机梯度下降：  \n",
    "\n",
    "$$(w,b)\\leftarrow(w,b)-\\frac{\\eta}{|\\beta|}\\sum_{i\\in\\beta}\\partial_{(w,b)}l^{(i)}(w,b)$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd(params,lr,batch_size):\n",
    "    for param in params:\n",
    "        param.data-=lr*param.grad/batch_size # 使用数据操作参数没有梯度跟踪"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1,loss: 0.041578203439712524\n",
      "epoch: 2,loss: 0.00015552206605207175\n",
      "epoch: 3,loss: 4.943608655594289e-05\n",
      "epoch: 4,loss: 4.8990626964950934e-05\n",
      "epoch: 5,loss: 4.8957488615997136e-05\n"
     ]
    }
   ],
   "source": [
    "# 超参数初始化\n",
    "LR=0.03\n",
    "NUM_EPOCHS=5\n",
    "\n",
    "# 训练\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    # X是特征向量，y是标签，尺寸是一个batch_size\n",
    "    for X,y in data_iter(batch_size,features,labels):\n",
    "        y_hat=linearRegression(X,w_learn,b_learn)\n",
    "        loss=squared_loss(y_hat,y).sum()\n",
    "        loss.backward() # 计算批量样本损失的梯度\n",
    "        sgd([w_learn,b_learn],LR,BATCH_SIZE) # 使用小批量随机梯度下降来迭代模型参数\n",
    "        # 重置参数梯度\n",
    "        w_learn.grad.data.zero_()\n",
    "        b_learn.grad.data.zero_()\n",
    "    train_loss=squared_loss(linearRegression(features,w_learn,b_learn),labels)\n",
    "    print(\"epoch: {},loss: {}\".format(epoch+1,train_loss.mean().item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([2, -3.4], tensor([[ 2.0000],\n",
       "         [-3.3999]], requires_grad=True), 4.2, tensor([4.2006], requires_grad=True))"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w,w_learn,b,b_learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 线性回归模型使用PyTorch的简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "torch.manual_seed(1) #每次初始化相同，使得实验可以复现\n",
    "\n",
    "torch.set_default_tensor_type(\"torch.FloatTensor\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 生成数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_INPUTS=2\n",
    "NUM_EXAMPLES=1000\n",
    "\n",
    "W=[2,-3.4]\n",
    "B=4.2\n",
    "\n",
    "features=torch.tensor(np.random.normal(0,1,(NUM_EXAMPELS,NUM_INPUTS)),dtype=torch.float)\n",
    "lables=W[0]*features[:,0]+W[1]*features[:,1]+B\n",
    "lables+=torch.tensor(np.random.normal(0,0.01,size=lables.size()),dtype=torch.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data as Data\n",
    "\n",
    "BATCH_SIZE=10\n",
    "\n",
    "# 组合数据集的特征和标签\n",
    "dataset=Data.TensorDataset(features,labels)\n",
    "\n",
    "# 将数据集放入dataloader\n",
    "data_iter=Data.DataLoader(\n",
    "    dataset=dataset,          # torch TensorDataset format\n",
    "    batch_size=BATCH_SIZE,    # 批量大小\n",
    "    shuffle=True,             # 是否打乱顺利\n",
    "    num_workers=0,             # 多线程\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.3894, -0.2153],\n",
      "        [ 1.0667, -0.9417],\n",
      "        [-0.1372, -0.0412],\n",
      "        [ 0.9436,  2.0076],\n",
      "        [-0.8629, -1.1596],\n",
      "        [ 0.9843,  0.5790],\n",
      "        [ 1.0678, -1.6272],\n",
      "        [ 1.1646, -0.9377],\n",
      "        [-0.9371,  0.9893],\n",
      "        [ 1.1103,  1.8677]]) \n",
      " tensor([-0.1295,  3.2085,  8.5177,  3.6092,  3.7359,  6.3555, -0.1464, 15.5801,\n",
      "         5.2434, -2.2822])\n"
     ]
    }
   ],
   "source": [
    "for X,y in data_iter:\n",
    "    print(X,'\\n',y)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearNet(nn.Module):\n",
    "    def __init__(self,n_feature):\n",
    "        super(LinearNet,self).__init__()  # 调用父类\n",
    "        self.linear=nn.Linear(n_feature,1) # 线性回归模型结果输出一维的预测值，原型torch.nnLinear(in_features,out_features,bias=True)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        y=self.linear(x)\n",
    "        return y\n",
    "\n",
    "net=LinearNet(NUM_INPUTS)\n",
    "# print(net[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "初始化多层网络的3种方法  \n",
    "- 方法1  \n",
    "```python\n",
    "net=nn.Sequential(nn.Linear(NUM_INPUTS,1)\n",
    "                  # 其他层可以在这里添加\n",
    "                 )\n",
    "```\n",
    "- 方法2  \n",
    "```python\n",
    "net=nn.Sequential()\n",
    "net.add_module('linear',nn.Linear(NUM_INPUTS,1))\n",
    "#net.add_module......\n",
    "```\n",
    "- 方法3  \n",
    "```python\n",
    "from collections import OrderedDict\n",
    "net=nn.Sequential(OrderedDict([\n",
    "    ('linear',nn.Linear(NUM_INPUTS,1))\n",
    "    #......\n",
    "]))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Linear(in_features=2, out_features=1, bias=True)\n",
      ")\n",
      "Linear(in_features=2, out_features=1, bias=True)\n"
     ]
    }
   ],
   "source": [
    "# method one\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(num_inputs, 1)\n",
    "    # other layers can be added here\n",
    "    )\n",
    "\n",
    "print(net)\n",
    "print(net[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([0.], requires_grad=True)"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.nn import init\n",
    "\n",
    "init.normal_(net[0].weight,mean=0.0,std=0.01) #n维的torch.Tensor,正态分布的均值，标准差\n",
    "init.constant_(net[0].bias,val=0.0) #n维的torch.Tensor，用来填充的张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.0034,  0.0019]], requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([0.], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "for param in net.parameters():\n",
    "    print(param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss=nn.MSELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 定义优化函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SGD (\n",
      "Parameter Group 0\n",
      "    dampening: 0\n",
      "    lr: 0.03\n",
      "    momentum: 0\n",
      "    nesterov: False\n",
      "    weight_decay: 0\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "optimizer=optim.SGD(net.parameters(),lr=0.03)\n",
    "print(optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1,loss:7.436114311218262\n",
      "epoch:2,loss:6.9896955490112305\n",
      "epoch:3,loss:21.924951553344727\n",
      "epoch:4,loss:11.167420387268066\n",
      "epoch:5,loss:14.600977897644043\n",
      "epoch:6,loss:11.437329292297363\n",
      "epoch:7,loss:10.282097816467285\n",
      "epoch:8,loss:11.784852981567383\n",
      "epoch:9,loss:8.335785865783691\n",
      "epoch:10,loss:12.937174797058105\n"
     ]
    }
   ],
   "source": [
    "NUM_EPOCHS=10\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    for X,y in data_iter:\n",
    "        y_hat=net(X)\n",
    "        loss_f=loss(y_hat,y.view(-1,1))\n",
    "        optimizer.zero_grad()\n",
    "        loss_f.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "    print(\"epoch:{},loss:{}\".format(epoch+1,loss_f.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, -3.4] tensor([[-0.0845,  0.1310]])\n",
      "4.2 tensor([4.5577])\n"
     ]
    }
   ],
   "source": [
    "dense=net[0]\n",
    "print(W,dense.weight.data)\n",
    "print(B,dense.bias.data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
